{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 手写数字识别简介\n",
    "\n",
    "*Copyright (c) 2022 Institute for Quantum Computing, Baidu Inc. All Rights Reserved.*\n",
    "\n",
    "计算机视觉（Computer Vision, CV）是指让计算机能够从图像、视频或其它视觉输入中获取有意义的信息。它是人工智能领域中的一个非常基础且重要的组成部分。在 CV 中，手写数字识别（handwritten digit classification）是一个较为基础的任务。它在 MNIST 数据集\\[1\\]上进行训练和测试，用来验证模型是否拥有 CV 方面的基础能力。\n",
    "\n",
    "MNIST 数据集中包含如下图所示的手写数字。MNIST 共包含 0-9 这 10 个类别，每个数字为 28\\*28 像素的灰度图片。其中，训练集有 60000 张图片，测试集有 10000 张图片。假设我们设计了一个可以用来进行图像分类的模型，那么我们可以在 MNIST 数据集上测试该模型的分类能力。\n",
    "\n",
    "![mnist-example](mnist_example.png)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用 VSQL 模型实现 MNIST 分类\n",
    "\n",
    "### 数据编码\n",
    "\n",
    "在手写数字识别问题中，输入是一张手写数字图片，输出是该图片对应的类别（即数字 0-9）。而由于量子计算机处理的输入是量子态，因此，我们需要将图片编码为量子态。在这里，我们首先使用一个二维矩阵表示一张图片。然后将该矩阵展开为一维向量，并通过补充 0 将向量长度补充到 2 的整数次幂。再对向量进行归一化，即可得到一个量子计算机可以处理的量子态。\n",
    "\n",
    "### VSQL 模型简介\n",
    "\n",
    "变分影子量子学习（variational shadow quantum learning, VSQL）是一个在监督学习框架下的量子–经典混合算法。它使用了参数化量子电路（parameterized quantum circuit, PQC）和经典影子（classical shadow），和通常使用的变分量子算法（variational quantum algorithm, VQA）不同的是，VSQL 只从子空间获取局部特征，而不是从量子态形成的整个希尔伯特空间获取特征。\n",
    "\n",
    "VSQL 的模型原理图如下：\n",
    "\n",
    "![vsql-model](vsql_model.png)\n",
    "\n",
    "VSQL 处理的输入是一个量子态。对于输入的量子态，迭代地作用一个局部参数化量子电路并进行测量，得到局部的影子特征。然后将得到的所有影子特征使用经典神经网络进行计算并得到预测标签。\n",
    "\n",
    "### 工作流\n",
    "\n",
    "根据以上原理，我们只需要使用 MNIST 数据集对 VSQL 模型进行训练。得到收敛后的模型。使用该模型即可进行手写数字的分类。模型的训练流程如下图：\n",
    "\n",
    "\n",
    "![vsql-pipeline](vsql_pipeline_cn.png)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 如何使用\n",
    "\n",
    "### 使用模型进行预测\n",
    "\n",
    "这里，我们已经给出了一个训练好的模型，可以直接用于 0 和 1 的图片的预测。只需要在 `example.toml` 这个配置文件中进行对应的配置，然后输入命令 `python vsql_classification.py --config example.toml` 即可使用训练好的 VSQL 模型对输入的图片进行测试。\n",
    "\n",
    "### 在线演示\n",
    "\n",
    "这里，我们给出一个在线演示的版本，可以在线进行测试。首先定义配置文件的内容："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_toml = r\"\"\"\n",
    "# 模型的整体配置文件。\n",
    "# 输入当前的任务，可以是 'train' 或者 'test'，分别代表训练和预测。这里我们使用 test，表示我们要进行预测。\n",
    "task = 'test'\n",
    "# 要预测的图片的文件路径。\n",
    "image_path = 'data_0.png'\n",
    "# 上面的图片路径是否是文件夹。对于文件夹路径，我们会对文件夹里面的所有图片文件进行预测。这种方式可以一次测试多个图片。\n",
    "is_dir = false\n",
    "# 训练好的模型参数文件的文件路径。\n",
    "model_path = 'vsql.pdparams'\n",
    "# 量子电路所包含的量子比特的数量。\n",
    "num_qubits = 10\n",
    "# 影子电路所包含的量子比特的数量。\n",
    "num_shadow = 2\n",
    "# 电路深度。\n",
    "depth = 1\n",
    "# 我们要预测的类别。这里我们对 0 和 1 进行分类。\n",
    "classes = [0, 1]\n",
    "\"\"\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来是预测部分的代码："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "对于输入的图片，模型有 89.22% 的信心认为它是 0，和 10.78% 的信心认为它是 1。\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'\n",
    "\n",
    "import toml\n",
    "from paddle_quantum.qml.vsql import train, inference\n",
    "\n",
    "config = toml.loads(test_toml)\n",
    "task = config.pop('task')\n",
    "if task == 'train':\n",
    "    train(**config)\n",
    "elif task == 'test':\n",
    "    prediction, prob = inference(**config)\n",
    "    if config['is_dir']:\n",
    "        print(f\"对输入图片的预测结果分别是 {str(prediction)[1:-1]}。\")\n",
    "    else:\n",
    "        prob = prob[0]\n",
    "        msg = '对于输入的图片，模型有'\n",
    "        for idx, item in enumerate(prob):\n",
    "            if idx == len(prob) - 1:\n",
    "                msg += '和'\n",
    "            label = config['classes'][idx]\n",
    "            msg += f' {item:3.2%} 的信心认为它是 {label:d}'\n",
    "            msg += '。' if idx == len(prob) - 1 else '，'\n",
    "        print(msg)\n",
    "else:\n",
    "    raise ValueError(\"未知的任务，它可以是'train'或'test'。\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在这里，我们只需要修改要配置文件中的图片路径，再运行整个代码，就可以快速对其它图片进行测试。"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 注意事项\n",
    "\n",
    "我们提供的模型为二分类模型，仅可以用来分辨手写数字 0 和 1。对于其它分类任务，需要重新进行训练。\n",
    "\n",
    "### 数据集结构\n",
    "\n",
    "如果想要使用自定义数据集进行训练，只需要按照规则来准备数据集即可。在数据集文件夹中准备 `train.txt` 和 `test.txt`，如果需要验证集的话还有 `dev.txt`。每个文件里使用一行代表一条数据。每行内容包含图片的文件路径和标签，使用制表符隔开。\n",
    "\n",
    "### 配置文件介绍\n",
    "\n",
    "在 `test.toml` 里有测试所需要的完整的配置文件内容参考。在 `train.toml` 里有训练所需要的完整的配置文件内容参考。使用配置文件的方式即可快速进行模型的训练和预测。"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. 引用信息\n",
    "\n",
    "```tex\n",
    "@inproceedings{li2021vsql,\n",
    "  title={VSQL: Variational shadow quantum learning for classification},\n",
    "  author={Li, Guangxi and Song, Zhixin and Wang, Xin},\n",
    "  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},\n",
    "  volume={35},\n",
    "  number={9},\n",
    "  pages={8357--8365},\n",
    "  year={2021}\n",
    "}\n",
    "```\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.15 (default, Nov 10 2022, 12:46:26) \n[Clang 14.0.6 ]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "49b49097121cb1ab3a8a640b71467d7eda4aacc01fc9ff84d52fcb3bd4007bf1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
